In [1]:
epochs = 15
Traditionally, PySyft has been used to facilitate federated learning. However, we can also leverage the tools included in this framework to implement distributed neural networks.
The training of a neural network (NN) is 'split' across one or more hosts. Each model segment is a self-contained NN that feeds into the segment in front of it. In this example, Alice has the unlabelled training data and the bottom of the network, whereas Bob has the corresponding labels and the top of the network. The image below shows this training process, where Bob holds all the labels and there are multiple Alices holding X data [1]. Once $Alice_1$ has trained, she sends a copy of her trained bottom segment to the next Alice. This continues until $Alice_n$ has trained.
In this case, both parties can train the model without seeing each other's data or the full details of the model. When an Alice finishes training, she passes her bottom segment to the next data holder.
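This hand-off between data holders is not implemented in the notebook below. As a minimal sketch of the idea, assuming a second virtual worker `alice2` that is not part of this tutorial's setup, the trained bottom segment could simply be retrieved from the first Alice and re-sent to the next one:
In [ ]:
# Hypothetical hand-off: pass the trained bottom segment from one Alice to the next.
# `alice2` is an assumed extra worker; it does not appear in the cells below.
alice2 = sy.VirtualWorker(hook, id="alice2")

bottom_model = models[0].get()  # retrieve Alice's trained bottom segment
bottom_model.send(alice2)       # hand it to the next data holder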
The SplitNN has been shown to dramatically reduce the computational burden of training while maintaining higher accuracies when training over a large number of clients [2]. In the figure below, the blue line denotes distributed deep learning using SplitNN, the red line indicates federated learning (FL), and the green line indicates large-batch stochastic gradient descent (LBSGD).
Table 1 shows the computational resources consumed when training CIFAR 10 over VGG. These are a fraction of the resources required by FL and LBSGD. Table 2 shows the bandwidth usage when training CIFAR 100 over ResNet. Federated learning is less bandwidth-intensive with fewer than 100 clients; however, the SplitNN outperforms the other approaches as the number of clients grows [2].
This tutorial demonstrates a basic example of SplitNN in which the network is split into two segments: Alice holds the training images and the bottom segment, while Bob holds the labels and the top segment.
Authors:
In [2]:
import torch
from torchvision import datasets, transforms
from torch import nn, optim
import syft as sy
hook = sy.TorchHook(torch)  # hook PyTorch so tensors gain PySyft's remote-execution functionality
In [3]:
# Data preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
trainset = datasets.MNIST('mnist', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
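As a quick sanity check (my addition, not part of the original notebook), the sketch below inspects one batch to show why the model's input size is 784: each 1 x 28 x 28 MNIST image is flattened into a 784-dimensional vector before being fed to the bottom segment.
In [ ]:
# Sketch: inspect a single batch to confirm the shapes the model segments expect.
images, labels = next(iter(trainloader))
print(images.shape)                             # torch.Size([64, 1, 28, 28])
print(images.view(images.shape[0], -1).shape)   # torch.Size([64, 784])
print(labels.shape)                             # torch.Size([64])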
In [4]:
torch.manual_seed(0)
# Define our model segments
input_size = 784
hidden_sizes = [128, 640]
output_size = 10
models = [
    nn.Sequential(
        nn.Linear(input_size, hidden_sizes[0]),
        nn.ReLU(),
        nn.Linear(hidden_sizes[0], hidden_sizes[1]),
        nn.ReLU(),
    ),
    nn.Sequential(
        nn.Linear(hidden_sizes[1], output_size),
        nn.LogSoftmax(dim=1)
    )
]
# Create an optimiser for each segment and link it to that segment's parameters
optimizers = [
    optim.SGD(model.parameters(), lr=0.03)
    for model in models
]
# create some workers
alice = sy.VirtualWorker(hook, id="alice")
bob = sy.VirtualWorker(hook, id="bob")
workers = alice, bob
# Send Model Segments to starting locations
model_locations = [alice, bob]
for model, location in zip(models, model_locations):
    model.send(location)
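At this point each segment lives on its assigned worker. The short check below (my addition) prints where each segment now resides; `.location` on a sent model is the same attribute the training function in the next cell relies on.
In [ ]:
# Sketch: confirm that each model segment has been sent to its intended worker.
for segment, worker in zip(models, model_locations):
    print("segment held by:", segment.location.id, "- expected:", worker.id)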
In [5]:
def train(x, target, models, optimizers):
    # Training logic

    # 1) erase previous gradients (if they exist)
    for opt in optimizers:
        opt.zero_grad()

    # 2) make a prediction
    a = models[0](x)

    # 3) break the computation graph link, and send the activation signal to the next model
    remote_a = a.detach().move(models[1].location).requires_grad_()

    # 4) make a prediction on the next model using the received signal
    pred = models[1](remote_a)

    # 5) calculate how much we missed
    criterion = nn.NLLLoss()
    loss = criterion(pred, target)

    # 6) figure out which weights caused us to miss
    loss.backward()

    # 7) send the gradient of the received activation signal to the model behind
    grad_a = remote_a.grad.copy().move(models[0].location)

    # 8) backpropagate on the bottom model given this gradient
    a.backward(grad_a)

    # 9) change the weights
    for opt in optimizers:
        opt.step()

    # 10) return the loss so we can track our progress
    return loss.detach().get()
In [6]:
for i in range(epochs):
    running_loss = 0
    for images, labels in trainloader:
        # send the image data to Alice and the labels to Bob
        images = images.send(alice)
        images = images.view(images.shape[0], -1)
        labels = labels.send(bob)
        loss = train(images, labels, models, optimizers)
        running_loss += loss
    print("Epoch {} - Training loss: {}".format(i, running_loss / len(trainloader)))
In [ ]: